'''
Class for hosting sqlcoder LLM
Configures LLM settings, model used, tokenizer and schema / instructions
Executes sql queries against a postgres database
Returns results to the calling function
'''
# AI and standard imports
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
import os
import sys
import json
import yaml
from pathlib import Path
from typing import Dict, Any, List, Optional
import time


# database imports
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.exc import SQLAlchemyError
import pandas as pd

''' Pydantic models for request and response '''
# Request payload holding a single free-text question.
# NOTE(review): not referenced elsewhere in the visible code - confirm it is still used.
class Query(BaseModel):
    question: str  # the user's question as plain text

# Response payload pairing a SQL string with metadata about the model.
# NOTE(review): not referenced elsewhere in the visible code - confirm it is still used.
class Response(BaseModel):
    sql_query: str  # SQL text
    model_info: Dict[str, Any]  # free-form model metadata

# Request payload accepted by generate_and_execute_query().
class QueryWithExecution(BaseModel):
    question: str  # user question; upper-cased before it is placed in the prompt
    max_rows: int = 100  # upper bound on rows fetched from the result set
    debug: bool = True  # NOTE(review): not read anywhere in the visible code
    indices: Optional[List[int]] = None  # Indices for vector search, if applicable

# Outcome of generating a SQL query and, when possible, running it.
class DatabaseResponse(BaseModel):
    sql_query: str  # SQL text extracted from the model output
    executed: bool  # set by generate_and_execute_query once the query was sent to the database
    results: Optional[List[Dict[str, Any]]] = None  # fetched rows as column -> value mappings
    row_count: Optional[int] = None  # number of rows in `results`
    error: Optional[str] = None  # error message when generation or execution failed
    execution_time_ms: Optional[float] = None  # elapsed time in milliseconds
    # generate_and_execute_query() supplies this value; without a declared field
    # it would not appear on the response object.
    user_question: Optional[str] = None  # the original question, echoed back for the caller
    model_info: Dict[str, Any]  # free-form model metadata

# Model configuration
MODEL_PATH = "/scratch/dhoward/sqlcoder-2"  # local directory the tokenizer and model are loaded from
MAX_TOKENS = 500  # max_new_tokens used for generation
USE_GPU = torch.cuda.is_available()  # only consulted by health_check() in the visible code
# Quantization switches: 4-bit is checked first, then 8-bit; with both False
# the model is loaded without a quantization config.
USE_4BIT_QUANTIZATION = False 
USE_8BIT_QUANTIZATION = False
USE_NO_QUANTIZATION = False  # NOTE(review): not read anywhere in the visible code
BATCH_SIZE = 1  # NOTE(review): not read anywhere in the visible code


# Database on myadvisor server
# Every value can be overridden through an environment variable so that
# deployments do not have to keep credentials in source control. The literals
# are kept as fallbacks so existing setups behave exactly as before.
DATABASE_CONFIG = {
    'host': os.environ.get('MYADVISOR_DB_HOST', '137.158.160.232'),
    # Kept as an int, as in the original configuration
    'port': int(os.environ.get('MYADVISOR_DB_PORT', '5432')),
    'dbname': os.environ.get('MYADVISOR_DB_NAME', 'myadvisor'),
    'user': os.environ.get('MYADVISOR_DB_USER', 'postgres'),
    'password': os.environ.get('MYADVISOR_DB_PASSWORD', 'myadvisor'),
}


# Model loading parameters
# 'use_fast' and 'trust_remote_code' are passed to the tokenizer loader;
# 'torch_dtype', 'device_map', 'max_memory', 'low_cpu_mem_usage' and
# 'trust_remote_code' are forwarded to AutoModelForCausalLM.from_pretrained.
MODEL_SETTINGS = {
    'low_cpu_mem_usage': True,
    'torch_dtype': torch.float16, # float 16 used for 4bit quant
    'device_map': 'auto',
    # Per-device memory caps: keys 0 and 1 are device indices, 'cpu' is host memory
    'max_memory': {
        0: '9GiB',  
        1: '9GiB',  
        'cpu': '8GiB'
    },
    'trust_remote_code': False,
    'use_fast': True,
    'use_cache': True  # NOTE(review): not read anywhere in the visible code
}

# Console debugging switches
DEBUG_QUERIES = True  # when True, debug_print_query() prints the generated SQL
DEBUG_RESULTS = True  # when True, debug_print_results() prints a result summary

def get_model_settings(model_name: str) -> Dict[str, Any]:
    """Return the default generation settings.

    ``model_name`` is accepted but not consulted: every caller receives the
    same values. The two token-id entries start as None and are meant to be
    filled in by the caller once a tokenizer is available.
    """
    return dict(
        top_p=0.95,
        top_k=40,
        repetition_penalty=1.1,
        max_new_tokens=MAX_TOKENS,
        do_sample=False,  # original note: False means 0 temperature
        temperature=0.01,
        pad_token_id=None,
        eos_token_id=None,
    )

def is_safe_query(sql_query: str) -> bool:
    '''Return True when the query contains no data-modifying keyword.

    The check is case-insensitive and matches the keywords as whole words
    only, so identifiers that merely contain one of them (for example
    ``created_at`` or ``updated_at``) are not rejected.

    Args:
        sql_query: SQL text to inspect.

    Returns:
        False if DROP, DELETE, UPDATE, INSERT, ALTER, CREATE or TRUNCATE
        appears as a standalone word anywhere in the text, True otherwise
        (including for an empty string).
    '''
    import re  # local import keeps this function self-contained

    dangerous_keywords = ('DROP', 'DELETE', 'UPDATE', 'INSERT', 'ALTER', 'CREATE', 'TRUNCATE')
    # \b treats letters, digits and underscores as word characters, so
    # "UPDATED_AT" does not match "UPDATE" while ";DELETE FROM" still does.
    keyword_pattern = r'\b(?:' + '|'.join(dangerous_keywords) + r')\b'
    return re.search(keyword_pattern, sql_query.upper()) is None

def debug_print_query(sql_query: str, context: str = ""):
    '''Echo a generated SQL query to stdout when DEBUG_QUERIES is enabled.

    ``context`` is appended to the header line to say where the query came from.
    '''
    if not DEBUG_QUERIES:
        return
    output_lines = (
        f"DEBUGGING SQL QUERY {context}",
        "Generated SQL:",
        f"{sql_query}",
        "=" * 60 + "\n",
    )
    for line in output_lines:
        print(line)

def debug_print_results(results: List[Dict], row_count: int, execution_time: float):
    """Summarise an executed query on stdout when DEBUG_RESULTS is enabled.

    Prints the row count, the execution time in milliseconds (two decimals)
    and, if any rows came back, every column of the first row.
    """
    if not DEBUG_RESULTS:
        return
    print("QUERY EXECUTION RESULTS")
    print(f"Rows returned: {row_count}")
    print(f"Execution time: {execution_time:.2f}ms")
    if results:
        print("Sample result (first row):")
        first_row = results[0]
        for column, cell in first_row.items():  # Show all columns
            print(f"  {column}: {cell}")
    print("=" * 60 + "\n")

# Import-time bootstrap: load the tokenizer and model from MODEL_PATH and wrap
# them in a text-generation pipeline (module-level name `generator`).
# Any exception raised in this section is printed and the process exits with status 1.
try:
    # Fail with a clear message when the local model directory is missing
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model path '{MODEL_PATH}' does not exist.")
    print("Loading model settings...")
    model_settings = get_model_settings(MODEL_PATH)
    # Choose a BitsAndBytes config from the module-level flags; 4-bit takes
    # precedence over 8-bit, and None means the model is loaded unquantized.
    quantization_config = None
    if USE_4BIT_QUANTIZATION:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
            llm_int8_enable_fp32_cpu_offload=True,
            llm_int8_threshold=6.0
        )
        print("4-bit quantization enabled")
    elif USE_8BIT_QUANTIZATION:
        print("8-bit quantization enabled")
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0
        )
    else: 
        print("No type of quantization selected, using FULL MODEL POWWWEEEEERRRRR!!!")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained( # using standard huggingface tokenizer
        MODEL_PATH,
        use_fast=MODEL_SETTINGS.get('use_fast', True),
        trust_remote_code=MODEL_SETTINGS.get('trust_remote_code', False)
    )
    print("Success, attemping to load model\n")
    # Set pad token if it doesn't exist
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Keyword arguments forwarded to from_pretrained, taken from MODEL_SETTINGS
    model_kwargs = {
        'torch_dtype': MODEL_SETTINGS['torch_dtype'],
        'device_map': MODEL_SETTINGS['device_map'],
        'max_memory': MODEL_SETTINGS['max_memory'],
        'low_cpu_mem_usage': MODEL_SETTINGS['low_cpu_mem_usage'],
        'trust_remote_code': MODEL_SETTINGS['trust_remote_code']
    }
    # Only pass a quantization config when one of the flags selected it
    if quantization_config is not None:
        model_kwargs['quantization_config'] = quantization_config
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, **model_kwargs)
    # Update model settings with tokenizer info
    # NOTE(review): model_settings is not read by generate_and_execute_query in
    # the visible code, which passes its own generation arguments.
    model_settings['pad_token_id'] = tokenizer.pad_token_id
    model_settings['eos_token_id'] = tokenizer.eos_token_id
    # Create pipeline with optimized settings
    print("Creating pipeline...")
    generator = pipeline(
        "text-generation", # using text generation pipeline
        model=model, 
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        device_map=MODEL_SETTINGS['device_map']
    )
    
    # Human-readable placement description, used only in the log line below
    device_str = "distributed across CPU & GPU" if MODEL_SETTINGS['device_map'] == "auto" else MODEL_SETTINGS['device_map']
    print(f"Model loaded successfully using {device_str}")
    # print(f"Model settings: {model_settings}") # for testing settings loaded
    
except Exception as e:
    # Importing this module without a loadable model terminates the process
    print(f"Error initializing model, the following error was returned: {str(e)}")
    sys.exit(1)

# Connection URL assembled from the DATABASE_CONFIG entries
DATABASE_URL = "postgresql://{user}:{password}@{host}:{port}/{dbname}".format(**DATABASE_CONFIG)
# Build the module-level SQLAlchemy engine; on failure `engine` is set to None
# and the database-facing functions report the missing connection.
# NOTE(review): per the SQLAlchemy docs create_engine() connects lazily, so the
# success message below likely only confirms that the engine object was built -
# confirm against the SQLAlchemy version in use.
try:
    engine = create_engine(
        DATABASE_URL,
        echo=False,  # Set to True for SQLAlchemy query logging
        pool_pre_ping=True,
        pool_size=5,
        max_overflow=10,
    )
    print("Database connection established successfully")
except Exception as connect_error:
    print(f"Error connecting to database: {connect_error}")
    engine = None

def get_schema():
    '''
    Describe the database tables and columns using SQLAlchemy inspection.

    Returns a dict: {"schema": <text>} with one "Table: name" line per table
    followed by an indented "  - column (type)" line per column, or
    {"error": <message>} when no engine is available.
    '''
    if engine is None:
        return {"error": "Database not connected, check database status."}
    db_inspector = inspect(engine)
    lines = []
    for table in db_inspector.get_table_names():
        lines.append(f"Table: {table}")
        lines.extend(
            f"  - {col['name']} ({col['type']})"
            for col in db_inspector.get_columns(table)
        )
    return {"schema": "\n".join(lines)}

def get_instructions() -> str:
    """Build the system instructions placed ahead of every question.

    The text contains the current database schema followed by the fixed
    rules and examples for translating questions into PostgreSQL.

    Returns:
        The instruction text. When the schema cannot be obtained because the
        database is not connected, the error message reported by
        get_schema() appears in the SCHEMA section instead.
    """
    schema_info = get_schema()
    # get_schema() returns a dict; interpolating it directly would put a
    # Python dict repr (with escaped "\n") into the prompt rather than the
    # readable schema text.
    schema = schema_info.get("schema")
    if schema is None:
        schema = schema_info.get("error", "")
    return f"""

        You are an expert in translating natural language into PostgreSQL queries for a university database.

        SCHEMA:
        {schema}

        CORE RULES:
        1. Return ONLY the SQL query - no explanations, no markdown formatting
        2. Use INNER JOINs unless specifically asked for optional/missing data
        3. Never hardcode IDs - always join through descriptive names

        COURSE CODE FORMAT:
        4. Course codes follow the format: [3 LETTERS][4 NUMBERS][SEMESTER LETTER]
           - Examples: CSC1015F, MAM1000W, STA2007S, PHY1004F
           - 3 letters: Department code (CSC, MAM, STA, PHY, etc.)
           - 4 numbers: Course number (1015, 1000, 2007, etc.)
           - Semester letter: F (First semester), S (Second semester), W (Whole year), Z (Summer), P (Preliminary), etc.

        CRITICAL COURSE CODE MATCHING RULES:
        5. Course codes in the database are stored in UPPERCASE format
        6. When users provide course codes WITHOUT semester letters (like "CSC1015", "mam1000"):
           - Always use: WHERE course_code LIKE 'CSC1015%' OR course_code LIKE 'MAM1000%'
        7. When users provide FULL course codes WITH semester letters (like "csc1015f" or "MAM1000W"):
           - Convert to uppercase in the query
           - Use exact match: WHERE course_code = 'CSC1015F' OR course_code = 'MAM1000W'
        8. Always use LIKE with % when semester letter is missing to catch all semester variants
        9. If user asks about a course, always return the course outline along with other information

        DEGREE-SPECIFIC COURSE QUESTIONS:
        10. **IMPORTANT**: If the question is asking about courses for a specific degree, program, major, or academic pathway (e.g., "What courses do I need for Computer Science degree?", "First year BSc courses", "Required courses for my major"), do NOT generate any SQL query.
        11. For such questions, return exactly: "HANDBOOK_REDIRECT"
        12. This applies to questions about course structures, study plans, degree requirements, and program-specific curricula.
        EXAMPLES:
        - User asks: "WHAT IS THE COURSE NAME OF STA5090?"
          SQL: SELECT course_name FROM course WHERE course_code LIKE 'STA5090%';

        - User asks: "WHAT COURSES DO I NEED IF IM STUDYING BUSINESS SCIENCE COMPUTER SCIENCE?"
          Response: HANDBOOK_REDIRECT
        
        - User asks: "WHAT IS THE COURSE NAME OF CSC1015F?"
          SQL: SELECT course_name FROM course WHERE course_code = 'CSC1015F';
        
        - User asks: "SHOW ME ALL CSC2001 COURSES"
          SQL: SELECT * FROM course WHERE course_code LIKE 'CSC2001%';

        """

# Health check endpoint
def health_check():
    """Return a static health report together with the detected device type.

    "device" is "cuda" when USE_GPU is truthy and "cpu" otherwise; the other
    two entries are constants.
    """
    device = "cpu"
    if USE_GPU:
        device = "cuda"
    return dict(status="healthy", model_loaded=True, device=device)
    
def generate_and_execute_query(query: QueryWithExecution): 
    '''
    Generate a SQL query for ``query.question`` and execute it against the database.

    The question is upper-cased, appended to the instructions from
    get_instructions() and sent to the text-generation pipeline. The text
    after the last "SQL:" marker in the output is taken as the query. When a
    database engine is available and is_safe_query() accepts the query, it is
    executed and at most ``query.max_rows`` rows are fetched.

    Args:
        query: request carrying the question and the row limit.

    Returns:
        A DatabaseResponse in every case. ``error`` is set when the database
        is not connected, the query is rejected as unsafe, the database
        reports an error, or any other exception occurs while processing.
        The user's question is included in the response data for the
        chatbot LLM's reference.
    '''
    start_time = time.time()  # timing covers generation as well as execution
    # Defined before the try block so the fallback response below can always be built
    sql_query = ""
    model_info = {"model": MODEL_PATH}
    try:
        instructions = get_instructions()
        prompt = f"{instructions}\n\nQuestion: {query.question.upper()}\nSQL:"
        result = generator(
            prompt,
            max_new_tokens=MAX_TOKENS,
            do_sample=True,
            temperature=0.01,  # sampling is on, but at a very low temperature
            pad_token_id=tokenizer.eos_token_id
        )
        generated_text = result[0]["generated_text"]
        # Keep only what follows the final "SQL:" marker
        sql_query = generated_text.split("SQL:")[-1].strip()
        debug_print_query(sql_query, "(Generate and Execute)")
        # NOTE(review): the instructions may lead the model to answer
        # HANDBOOK_REDIRECT; that text is still sent to the database below -
        # confirm callers handle the resulting error.
        response_data = {
            "sql_query": sql_query,
            "executed": False,
            "model_info": model_info
        }
        if engine is None:
            response_data["error"] = "Database not connected, check database status."
        else:
            try:
                if not is_safe_query(sql_query):
                    raise ValueError("Only select queries are allowed for security")
                with engine.connect() as database_conn:
                    print(f"Attempting to execute query: {sql_query}")
                    db_result = database_conn.execute(text(sql_query))
                    # Fetch result with row limit
                    rows = db_result.fetchmany(query.max_rows)
                    columns = db_result.keys()
                    # Convert to a list of dictionaries
                    results = [dict(zip(columns, row)) for row in rows]
                    execution_time = (time.time() - start_time) * 1000  # milliseconds
                    debug_print_results(results, len(results), execution_time)
                    response_data.update({
                        "executed": True,
                        # NOTE(review): only surfaces on the response if
                        # DatabaseResponse declares a matching field.
                        "user_question": query.question,
                        "results": results,
                        "row_count": len(results),
                        "execution_time_ms": execution_time
                    })
            except SQLAlchemyError as e:
                execution_time = (time.time() - start_time) * 1000
                error_msg = f"Database error: {str(e)}"
                response_data.update({
                    "executed": True,
                    "error": error_msg,
                    # Must match the DatabaseResponse field name to be reported
                    "execution_time_ms": execution_time
                })
            except ValueError as e:
                error_msg = str(e)
                print(error_msg)
                response_data.update({
                    "executed": False,
                    "error": error_msg
                })
        return DatabaseResponse(**response_data)
    except Exception as e:
        error_msg = f"Error processing query: {str(e)}"
        print(error_msg)
        # Report the failure to the caller instead of implicitly returning None
        return DatabaseResponse(
            sql_query=sql_query,
            executed=False,
            error=error_msg,
            execution_time_ms=(time.time() - start_time) * 1000,
            model_info=model_info
        )


def check_database_status():
    """Report whether the database can be reached.

    Returns a dict whose "status" is "disconnected" when no engine exists,
    "connected" (plus database name, host and port) when a trivial
    ``SELECT 1`` round-trip succeeds, or "error" with the exception text
    when it fails.
    """
    if engine is None:
        return dict(status="disconnected", error="Database engine not initialized")
    try:
        with engine.connect() as connection:
            # Round-trip a trivial statement to prove the connection works
            connection.execute(text("SELECT 1")).fetchone()
        return dict(
            status="connected",
            database=DATABASE_CONFIG['dbname'],
            host=DATABASE_CONFIG['host'],
            port=DATABASE_CONFIG['port'],
        )
    except Exception as failure:
        return dict(status="error", error=str(failure))
    

